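"""
Phase 1: Data Assembly for the Monetary Policy Project.

Merges full_panel with the trilemma panel (which supplies the exchange-rate
regime variables), adds a CBI proxy and inflation-targeting / QE indicators,
and constructs interaction, lag/lead, and income-group variables.
"""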

import sys
from pathlib import Path
import numpy as np
import pandas as pd

# ── paths ────────────────────────────────────────────────────────────────────
# Absolute project root; every input and output location below hangs off it.
PROJECT = Path("/mnt/c/demographics_capital_flows")
_PROCESSED = ("data", "processed")
FULL_PANEL = PROJECT.joinpath("multilateral", "followup", *_PROCESSED, "full_panel.csv")
TRILEMMA_PANEL = PROJECT.joinpath("trilemma", *_PROCESSED, "trilemma_panel.csv")
_MONETARY = PROJECT / "monetary"
OUT_DIR = _MONETARY.joinpath(*_PROCESSED)
TABLE_DIR = _MONETARY / "output" / "tables"
# Output directories are created eagerly, at import time.
for _out in (OUT_DIR, TABLE_DIR):
    _out.mkdir(parents=True, exist_ok=True)

# Put the "multilateral" sources first on the import path
# (NOTE(review): nothing visible in this file imports from it — confirm still needed).
sys.path.insert(0, str(PROJECT / "multilateral" / "src"))

# ── OECD-38 ──────────────────────────────────────────────────────────────────
# ISO3 codes of the 38 OECD members. This list is used to fill gaps in the
# trilemma panel's own is_oecd flag.
OECD_38 = """
    AUS AUT BEL CAN CHL COL CRI CZE DNK EST
    FIN FRA DEU GRC HUN ISL IRL ISR ITA JPN
    KOR LVA LTU LUX MEX NLD NZL NOR POL PRT
    SVK SVN ESP SWE CHE TUR GBR USA
""".split()

# ── Inflation Targeting adoption years ───────────────────────────────────────
# Year each country adopted inflation targeting, listed grouped by year and
# flattened into an {iso3: year} map below. Only adoption years are stored.
# NOTE(review): no exit years are recorded, so a country that later gave up an
# independent IT framework (e.g. on euro adoption) stays flagged — confirm intended.
_IT_ADOPTERS_BY_YEAR = {
    1990: ("NZL",),
    1991: ("CAN",),
    1992: ("GBR",),
    1993: ("SWE", "FIN", "AUS"),
    1995: ("ESP",),
    1997: ("ISR",),
    1998: ("CZE", "POL"),
    1999: ("BRA", "CHL", "COL"),
    2000: ("ZAF", "THA"),
    2001: ("KOR", "MEX", "NOR", "ISL", "HUN"),
    2002: ("PER", "PHL"),
    2005: ("GTM", "IDN", "ROU"),
    2006: ("TUR", "SRB", "ARM"),
    2007: ("GHA",),
    2009: ("ALB", "GEO"),
    2010: ("MDA",),
    2011: ("PRY",),
    2012: ("DOM",),
    2013: ("JPN",),
    2015: ("RUS", "KAZ"),
    2016: ("IND", "UKR"),
    2018: ("JAM", "CRI"),
    2019: ("UZB",),
}
IT_ADOPTION = {
    iso: year for year, isos in _IT_ADOPTERS_BY_YEAR.items() for iso in isos
}

# ── QE episodes ──────────────────────────────────────────────────────────────
# One inclusive (start_year, end_year) span per country.
# Euro-area members share the ECB's 2015-2022 span.
_EURO_AREA_QE = (
    "DEU", "FRA", "ITA", "ESP", "NLD", "BEL", "AUT", "PRT", "FIN", "IRL",
    "GRC", "LUX", "SVK", "SVN", "EST", "LVA", "LTU", "CYP", "MLT",
)
QE_EPISODES = {
    "USA": (2008, 2014),
    "GBR": (2009, 2022),
    # Japan's two episodes (2001-2006 and 2013-2024) are collapsed into a single
    # span because the structure allows one span per country.
    # NOTE(review): this also flags 2007-2012 as QE-active — confirm intended.
    "JPN": (2001, 2024),
    "SWE": (2015, 2019),
    "CHE": (2015, 2022),
    **dict.fromkeys(_EURO_AREA_QE, (2015, 2022)),
}

# Post-QE tightening start years. Every listed central bank is dated 2022.
# Japan has a QE span but no tightening entry here.
POST_QE_START = dict.fromkeys(
    (
        "USA", "GBR", "DEU", "FRA", "ITA", "ESP", "NLD", "BEL", "AUT", "PRT",
        "FIN", "IRL", "GRC", "LUX", "SVK", "SVN", "EST", "LVA", "LTU", "CYP",
        "MLT", "SWE", "CHE",
    ),
    2022,
)


def _load_panels():
    """Load full_panel and the trilemma columns, both truncated at 2024.

    Returns:
        (fp, tri): the base panel and a trilemma frame de-duplicated on
        (iso3, year), restricted to the trilemma columns present in the file.
    """
    fp = pd.read_csv(FULL_PANEL)
    fp = fp[fp["year"] <= 2024].copy()
    print(f"\nfull_panel: {len(fp):,} obs, {fp['iso3'].nunique()} countries, "
          f"years {fp['year'].min()}-{fp['year'].max()}")

    tri = pd.read_csv(TRILEMMA_PANEL)
    tri = tri[tri["year"] <= 2024].copy()
    # Keep only trilemma-specific columns that actually exist in the file.
    tri_cols = ["iso3", "year", "mi_index", "ers_index", "fo_index",
                "regime_coarse", "regime_3cat", "regime_fine",
                "is_peg", "is_float", "eurozone", "is_oecd"]
    tri_cols = [c for c in tri_cols if c in tri.columns]
    tri = tri[tri_cols].drop_duplicates(subset=["iso3", "year"])
    print(f"trilemma:   {len(tri):,} obs, {tri['iso3'].nunique()} countries")
    return fp, tri


def _previous_year_value(panel, var):
    """Return `var` from year - 1 of the same country, aligned to `panel`'s index.

    A plain row shift would silently span several years when a country has gaps
    in its years. Here the result is NaN unless the preceding row of the country
    is exactly one year earlier. Rows should be sorted by (iso3, year); if they
    are not, the function errs toward NaN rather than returning a wrong lag.
    """
    by_country = panel.groupby("iso3")
    prev_val = by_country[var].shift(1)
    prev_year = by_country["year"].shift(1)
    return prev_val.where(panel["year"] - prev_year == 1)


def _add_policy_indicators(panel):
    """Add is_oecd, it_adopter, qe_active/qe_country/post_qe_tightening and cbi_index."""
    # OECD indicator: fill gaps in (or create) the trilemma flag from OECD_38.
    oecd_from_list = panel["iso3"].isin(OECD_38).astype(float)
    if "is_oecd" not in panel.columns:
        panel["is_oecd"] = oecd_from_list
    else:
        panel["is_oecd"] = panel["is_oecd"].fillna(oecd_from_list)

    # Inflation targeting: 1 from the adoption year onward. Countries absent
    # from IT_ADOPTION map to NaN, and the comparison yields False (0.0).
    it_year = panel["iso3"].map(IT_ADOPTION)
    panel["it_adopter"] = (panel["year"] >= it_year).astype(float)
    n_it = panel.loc[panel["it_adopter"] == 1, "iso3"].nunique()
    print(f"IT adopters in sample: {n_it} countries")

    # QE: active within the inclusive (start, end) span; 0.0 outside or if unlisted.
    qe_start = panel["iso3"].map({iso: span[0] for iso, span in QE_EPISODES.items()})
    qe_end = panel["iso3"].map({iso: span[1] for iso, span in QE_EPISODES.items()})
    panel["qe_active"] = (
        (panel["year"] >= qe_start) & (panel["year"] <= qe_end)
    ).astype(float)
    panel["qe_country"] = panel["iso3"].isin(QE_EPISODES).astype(float)
    panel["post_qe_tightening"] = (
        panel["year"] >= panel["iso3"].map(POST_QE_START)
    ).astype(float)
    print(f"QE-active obs: {panel['qe_active'].sum():.0f}")
    print(f"Post-QE obs:   {panel['post_qe_tightening'].sum():.0f}")

    # CBI proxy: Romelli/Garriga CBI is not in the dataset; mi_index
    # (0-1 monetary independence) is the best available proxy.
    panel["cbi_index"] = panel["mi_index"].copy()
    print(f"CBI proxy (mi_index) coverage: {panel['cbi_index'].notna().sum():,}")
    return panel


def _add_rate_dynamics(panel):
    """Sort by country-year; add one-year changes, real rate, ZLB flag, lagged inflation."""
    panel = panel.sort_values(["iso3", "year"])
    for var in ["policy_rate", "inflation", "short_rate_3m", "govt_bond_10y"]:
        if var in panel.columns:
            panel[f"delta_{var}"] = panel[var] - _previous_year_value(panel, var)

    panel["real_policy_rate"] = panel["policy_rate"] - panel["inflation"]

    # Near ZLB: policy rate <= 1 (percent units assumed); NaN where the rate is missing.
    panel["near_zlb"] = (panel["policy_rate"] <= 1.0).astype(float)
    panel.loc[panel["policy_rate"].isna(), "near_zlb"] = np.nan

    panel["inflation_lag"] = _previous_year_value(panel, "inflation")
    return panel


def _add_global_z(panel):
    """Add ngdp_usd-weighted yearly means global_Z_p and deviations domestic_Z_p_dev."""
    for p in [1, 2, 3]:
        zvar = f"Z_{p}"
        w = panel[["year", zvar, "ngdp_usd"]].dropna()
        weighted_sum = (w[zvar] * w["ngdp_usd"]).groupby(w["year"]).sum()
        weight_total = w["ngdp_usd"].groupby(w["year"]).sum()
        # A year whose weights sum to zero gets NaN instead of raising.
        global_z = (
            weighted_sum / weight_total.where(weight_total != 0)
        ).rename(f"global_{zvar}")
        panel = panel.merge(global_z.reset_index(), on="year", how="left")
        panel[f"domestic_{zvar}_dev"] = panel[zvar] - panel[f"global_{zvar}"]
    return panel


def _add_interactions(panel):
    """Add product terms for every (a, b) pair whose columns both exist."""
    interactions = {
        "Z_1_x_output_gap": ("Z_1", "output_gap"),
        "Z_1_x_delta_rate": ("Z_1", "delta_policy_rate"),
        "Z_1_x_cbi": ("Z_1", "cbi_index"),
        "Z_1_x_it": ("Z_1", "it_adopter"),
        "Z_1_x_qe": ("Z_1", "qe_active"),
        "Z_1_x_near_zlb": ("Z_1", "near_zlb"),
        "Z_2_x_output_gap": ("Z_2", "output_gap"),
        "Z_2_x_cbi": ("Z_2", "cbi_index"),
        "Z_2_x_it": ("Z_2", "it_adopter"),
        "Z_2_x_qe": ("Z_2", "qe_active"),
        "Z_3_x_output_gap": ("Z_3", "output_gap"),
        "Z_3_x_cbi": ("Z_3", "cbi_index"),
        "Z_3_x_it": ("Z_3", "it_adopter"),
        "Z_3_x_qe": ("Z_3", "qe_active"),
        "old_dep_x_output_gap": ("old_dep", "output_gap"),
        "working_age_share_x_output_gap": ("working_age_share", "output_gap"),
        "domestic_Z_1_dev_x_kaopen": ("domestic_Z_1_dev", "kaopen"),
    }
    for name, (a, b) in interactions.items():
        if a in panel.columns and b in panel.columns:
            panel[name] = panel[a] * panel[b]
    return panel


def _add_z_lags_and_changes(panel):
    """Add year-keyed 5-year lags Z_p_lag5 and one-year changes dZ_p."""
    for p in [1, 2, 3]:
        zvar = f"Z_{p}"
        lagged = panel[["iso3", "year", zvar]].copy()
        lagged["year"] = lagged["year"] + 5
        lagged = lagged.rename(columns={zvar: f"{zvar}_lag5"})
        panel = panel.merge(lagged, on=["iso3", "year"], how="left")

    for p in [1, 2, 3]:
        zvar = f"Z_{p}"
        panel[f"dZ_{p}"] = panel[zvar] - _previous_year_value(panel, zvar)
    return panel


def _add_forward_cumulatives(panel):
    """Add cum_growth_h / cum_inflation_h = sum of the variable over years t+1..t+h.

    Each value requires all h forward years to be present; otherwise it is NaN.
    """
    for h in range(1, 6):
        for var, prefix in [("rgdp_growth", "cum_growth"), ("inflation", "cum_inflation")]:
            col = f"{prefix}_{h}"
            fwd = panel[["iso3", "year", var]].copy()
            fwd["year"] = fwd["year"] - h
            fwd = fwd.rename(columns={var: col})
            panel = panel.merge(fwd, on=["iso3", "year"], how="left")
            if h > 1:
                # `col` holds the single-year value at t+h. Add the running total
                # through t+h-1 (already cumulative). A NaN at any horizon
                # propagates, so all h years are required.
                panel[col] = panel[f"{prefix}_{h - 1}"] + panel[col]
    return panel


def _add_income_terciles(panel):
    """Label countries low/middle/high by their median gdp_pc_ppp (NaN if no data)."""
    gdp_med = panel.groupby("iso3")["gdp_pc_ppp"].median()
    t1, t2 = gdp_med.quantile(0.33), gdp_med.quantile(0.67)

    def _label(x):
        if pd.isna(x):
            # No GDP data: leave unclassified rather than falling through to "high".
            return np.nan
        if x <= t1:
            return "low"
        return "middle" if x <= t2 else "high"

    panel["income_tercile"] = panel["iso3"].map(gdp_med.map(_label))
    return panel


def _write_outputs(panel):
    """Save the panel CSV and summary-stat tables (CSV + Markdown); print coverage."""
    out_path = OUT_DIR / "monetary_panel.csv"
    panel.to_csv(out_path, index=False)
    print(f"\nSaved: {out_path}")
    print(f"  {len(panel):,} obs, {panel['iso3'].nunique()} countries, "
          f"{panel.shape[1]} columns")

    key_vars = [
        "Z_1", "Z_2", "Z_3", "inflation", "policy_rate", "real_policy_rate",
        "govt_bond_10y", "short_rate_3m", "real_bond_10y", "real_short_3m",
        "term_spread", "output_gap", "rgdp_growth", "fiscal_bal_gdp",
        "kaopen", "nfa_gdp_lag", "cbi_index", "it_adopter", "qe_active",
        "near_zlb", "old_dep", "youth_dep", "working_age_share",
        "delta_policy_rate", "delta_inflation",
        "global_Z_1", "domestic_Z_1_dev",
    ]
    key_vars = [v for v in key_vars if v in panel.columns]
    stats = panel[key_vars].describe().T
    stats["coverage"] = panel[key_vars].notna().sum()

    stats_path = TABLE_DIR / "phase1_summary_stats.csv"
    stats.to_csv(stats_path)
    print(f"Saved: {stats_path}")

    # Markdown summary
    md_lines = ["# Phase 1: Summary Statistics\n"]
    md_lines.append(f"Panel: {len(panel):,} obs, {panel['iso3'].nunique()} countries, "
                    f"years {panel['year'].min()}-{panel['year'].max()}\n")
    md_lines.append("| Variable | N | Mean | Std | Min | Max |")
    md_lines.append("|----------|---|------|-----|-----|-----|")
    for var in key_vars:
        s = panel[var].dropna()
        md_lines.append(
            f"| {var} | {len(s):,} | {s.mean():.3f} | {s.std():.3f} "
            f"| {s.min():.3f} | {s.max():.3f} |"
        )
    md_path = TABLE_DIR / "phase1_summary_stats.md"
    md_path.write_text("\n".join(md_lines), encoding="utf-8")
    print(f"Saved: {md_path}")

    # Coverage report
    print("\n── Key Variable Coverage ──")
    for v in ["policy_rate", "govt_bond_10y", "inflation", "output_gap",
              "cbi_index", "it_adopter", "qe_active"]:
        if v in panel.columns:
            n = panel[v].notna().sum()
            nc = panel.loc[panel[v].notna(), "iso3"].nunique()
            print(f"  {v:25s}: {n:6,} obs, {nc:3d} countries")


def main():
    """Assemble the monetary-policy panel and write it plus summary tables.

    Steps: load and left-merge full_panel with the trilemma columns; add policy
    indicators (OECD, IT, QE, CBI proxy); add rate dynamics; add the global-Z
    decomposition; add interactions; add Z lags and changes; add forward
    cumulative growth/inflation; add income terciles. Outputs are written to
    OUT_DIR and TABLE_DIR.
    """
    print("=" * 70)
    print("PHASE 1: DATA ASSEMBLY — MONETARY POLICY PANEL")
    print("=" * 70)

    fp, tri = _load_panels()
    panel = fp.merge(tri, on=["iso3", "year"], how="left")
    print(f"\nMerged:     {len(panel):,} obs, {panel['iso3'].nunique()} countries")

    panel = _add_policy_indicators(panel)
    panel = _add_rate_dynamics(panel)
    panel = _add_global_z(panel)
    panel = _add_interactions(panel)
    panel = _add_z_lags_and_changes(panel)
    panel = _add_forward_cumulatives(panel)
    panel = _add_income_terciles(panel)
    # Predetermined instrument: oadr_plus20 is already in full_panel.

    _write_outputs(panel)
    print("\nPhase 1 complete.")


if __name__ == "__main__":
    # Build the panel only when run as a script, not when imported.
    main()
